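"""Discrete object/proposition environments (base Env plus ball-drop, line, grid and drive worlds), e.g. a 1D arm that picks up and drops balls into a basket"""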
import numpy as np
from numpy.random import randint
from collections import OrderedDict
# efficient sparse matrix construction:
from scipy.sparse import dok_matrix
# efficient matrix-vector multiplication:
from scipy.sparse import csr_matrix

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from celluloid import Camera

from .simulator import Sim
from .rendering import Viewer
from .object import *
from .proposition import *

class Env(object):
    """Base environment made of objects and boolean propositions.

    The environment keeps:

    * ``objects`` / ``obj_dict``: the objects added through ``add_object(s)``.
    * ``props`` / ``prop_dict``: the propositions added through ``add_prop(s)``.
    * ``obj_state``: an array of shape ``(*dom_size, n_objects)`` with one
      layer per object (see ``add_obj_state`` / ``update_obj_state``).
    * ``color_array``: one row per object, taken from ``obj.color``.

    An object is treated as static (and excluded from the flat state) when
    the direct parent of its class is named ``'StaticObj'``.
    """

    def __init__(self, name, dom_size, action_dict):
        """Create an empty environment.

        Args:
            name: identifier stored on the environment.
            dom_size: sequence of domain extents; used as the leading
                dimensions of ``obj_state``.
            action_dict: mapping from action name to action value. In
                ``make_transition_function`` the values are used as indices
                into the list of transition matrices.
        """
        self.name = name
        self.dom_size = dom_size
        # one (initially absent) layer per object along the last axis
        self.obj_state = np.zeros([*dom_size, 0])
        self.color_array = np.zeros([0, 3])
        self.objects = []  # may not be necessary
        self.obj_dict = None
        self.obj_idx_dict = {}  # may not be necessary
        self.props = []  # may not be necessary
        self.prop_dict = None
        self.prop_idx_dict = {}  # may not be necessary

        # proposition map, filled in by make_prop_map()
        self.P = None

        # created lazily by render()
        self.viewer = None

        self.action_dict = action_dict

    @staticmethod
    def _is_static(obj):
        """Return True if the direct parent class of ``obj`` is named 'StaticObj'."""
        return type(obj).mro()[1].__name__ == 'StaticObj'

    def _prop_vector(self):
        """Return the current proposition values as a list of ints.

        The list has one entry per proposition plus a trailing "empty"
        entry that is 1 only when no other entry equals 1. When a
        ``CombinedProp`` is true, the two entries named by its
        ``prop_idxs`` are cleared.

        Raises:
            ValueError: if a proposition value cannot be converted to int.
                (Silently skipping it would shift every later index.)
        """
        prop_state = []
        for prop in self.props:
            try:
                prop_state.append(int(prop.value))
            except (TypeError, ValueError) as err:
                raise ValueError(
                    "proposition {!r} has a value that is not convertible "
                    "to int: {!r}".format(getattr(prop, 'name', prop), prop.value)
                ) from err
            if type(prop).__name__ == 'CombinedProp' and prop.value:
                prop_state[prop.prop_idxs[0]] = 0
                prop_state[prop.prop_idxs[1]] = 0

        # trailing "empty" proposition
        prop_state.append(0 if 1 in prop_state else 1)
        return prop_state

    def get_state(self):
        """Return the flat state: non-static object states, then prop values as ints."""
        state = []
        for obj in self.objects:
            # don't include obstacles
            if not self._is_static(obj):
                state.extend(list(obj.get_state()))
        for prop in self.props:
            state.append(int(prop.value))
        return state

    # ie [8, 6, 8, 6, 2, 2, 2, 2]
    #     x, y, x, y, p, p, p, p
    def get_full_state_space(self):
        """Return the size of every flat-state dimension (2 for each prop)."""
        state_space = []
        for obj in self.objects:
            # don't include obstacles
            if not self._is_static(obj):
                state_space.extend(obj.state_space)
        state_space.extend([2] * len(self.props))
        return state_space

    # state of the form [3, 2, 4, 8, 1, 0, 0, 0]
    def state_to_idx(self, state):
        """Convert a flat state into a single index via ``np.ravel_multi_index``."""
        state_space = self.get_full_state_space()
        return np.ravel_multi_index(tuple(state), tuple(state_space))

    def idx_to_state(self, idx):
        """Convert a flat index back into a state; returns the tuple from ``np.unravel_index``."""
        state_space = self.get_full_state_space()
        return np.unravel_index(idx, tuple(state_space))

    # given state of form [3, 2, 4, 8, 1, 0, 0, 0]
    # set object and prop states to match
    def set_state(self, state):
        """Push a flat state into the objects and props, then refresh ``obj_state``.

        Prop values are overwritten directly from the state entries; the
        props are not re-evaluated here.
        """
        idx = 0  # index along the whole state
        for obj in self.objects:
            # don't include obstacles
            if not self._is_static(obj):
                ndim = len(obj.state_space)
                obj.set_state(state[idx:idx + ndim])
                idx += ndim
        for prop in self.props:
            prop.value = bool(state[idx])
            idx += 1
        self.update_obj_state()

    # matrix/function that stores this info:
    # given a prop, returns in which states it is true
    # given a state, returns which props are true
    def make_prop_map(self):
        """Build, store in ``self.P`` and return the proposition map.

        ``P`` has shape ``(len(props) + 1, number_of_states)``; column
        ``s`` is the proposition vector of state ``s`` (last row is the
        "empty" proposition).

        Note: the environment is left in the last enumerated state.
        """
        # find the size of the state space
        full_state_space = self.get_full_state_space()
        ss_size = np.prod(full_state_space)

        # include empty prop
        nP = len(self.props) + 1

        P = np.zeros((nP, ss_size))

        # Checked before indexing prop_state so that environments with
        # fewer than two props do not hit an IndexError.
        is_drive_world = type(self).__name__ == 'DriveWorldEnv'

        # ndindex iterates through every state in (x, y, ...)
        for count, state in enumerate(np.ndindex(*full_state_space), start=1):
            state = list(state)
            s_idx = self.state_to_idx(state)
            self.set_state(state)

            prop_state = self._prop_vector()

            # this is hack for the driving env; if the 'goal c' prop and left lane prop
            # are both true, make it so that only the goal c prop is true
            if is_drive_world and prop_state[2] == 1:
                prop_state[3] = 0

            P[:, s_idx] = prop_state

            # progress output
            if count % 10000 == 0:
                print(count)

        self.P = P

        return P

    def get_proposition(self):
        """Return ``np.argmax`` of the current proposition vector.

        The last index corresponds to the "empty" proposition.
        """
        return np.argmax(self._prop_vector())

    # T[s, s']
    def make_transition_function(self, plot=False):
        """Build one sparse transition matrix per action.

        Every state is visited, every action is applied with ``step`` and
        ``T[action][s, s']`` is set to 1 for the resulting state. The
        state the environment had on entry is restored before returning.

        Args:
            plot: if true, call ``render(mode='fast')`` after each state.

        Returns:
            List of CSR matrices, indexed by the values of ``action_dict``.
        """
        initial_state = self.get_state()

        # find the size of the state space
        full_state_space = self.get_full_state_space()
        ss_size = np.prod(full_state_space)

        # dok_matrix for cheap incremental construction
        T = [dok_matrix((ss_size, ss_size)) for _ in self.action_dict]

        # ndindex iterates through every state in (x, y, ...)
        for count, state in enumerate(np.ndindex(*full_state_space), start=1):
            state = list(state)
            s_idx = self.state_to_idx(state)
            # iterate through every action
            for action_name in self.action_dict:
                self.set_state(state)
                self.step(action_name)
                ns_idx = self.state_to_idx(self.get_state())

                action = self.action_dict[action_name]
                T[action][s_idx, ns_idx] = 1

            # progress output
            if count % 10000 == 0:
                print(count)
            if plot:
                self.render(mode='fast')

        # csr for efficient matrix-vector products
        T = [t.tocsr() for t in T]

        self.set_state(initial_state)

        return T

    def make_reward_function(self):
        """Build the reward vector; must be provided by subclasses."""
        raise NotImplementedError()

    def step(self, action_name):
        """Advance the environment by one action.

        Args:
            action_name: a key of ``action_dict``, or a list/tuple that is
                used directly as the action (lists are converted to tuples).
        """
        if isinstance(action_name, list):
            action = tuple(action_name)
        elif isinstance(action_name, tuple):
            action = action_name
        else:
            action = self.action_dict[action_name]

        # 1. the objects step
        for obj in self.objects:
            obj.step(self, action)

        # 2. update the props
        for prop in self.props:
            prop.eval(self.obj_dict)

        # 3. update obj_state
        self.update_obj_state()

    def add_props(self, prop_dict):
        """Store ``prop_dict`` and register each of its values with ``add_prop``."""
        self.prop_dict = prop_dict
        for prop in prop_dict.values():
            self.add_prop(prop)

    def add_prop(self, prop):
        """Append a proposition and record its position under ``prop.name``."""
        self.props.append(prop)
        self.prop_idx_dict[prop.name] = len(self.props) - 1

    # given an OrderedDict of objects, add each object to the env
    def add_objects(self, obj_dict):
        """Store ``obj_dict`` and register each of its values with ``add_object``."""
        self.obj_dict = obj_dict
        for obj in obj_dict.values():
            self.add_object(obj)

    # add an object to the env by adding it to the object list
    # and by adding a dimension for the object to the object state
    def add_object(self, obj):
        """Append an object, its color row and its layer in ``obj_state``."""
        self.objects.append(obj)
        self.obj_idx_dict[obj.name] = len(self.objects) - 1
        self.color_array = np.append(self.color_array, obj.color[None], axis=0)
        self.add_obj_state(obj)

    # add object to obj_state
    def add_obj_state(self, obj):
        """Append one layer for ``obj`` along the last axis of ``obj_state``."""
        # if the object is a StaticObj, its state is a 2D array
        # of the entire domain already
        if self._is_static(obj):
            new_obj_array = obj.state[..., None]  # need to add extra singleton dim to state
        # if the object is not a StaticObj, its state needs to be
        # converted into the representation of a 2D array
        else:
            new_obj_array = np.zeros([*self.dom_size, 1])
            new_obj_array[obj.state[0], obj.state[1], 0] = 1
        self.obj_state = np.append(self.obj_state, new_obj_array, axis=-1)

    # update object in obj_state
    def update_obj_state(self):
        """Rewrite every layer of ``obj_state`` from the objects' current states."""
        for i, obj in enumerate(self.objects):
            # if StaticObj, set obj_state to match the StaticObj's state
            if self._is_static(obj):
                self.obj_state[..., i] = obj.state
            # else, wipe obj_state (set it to 0) and set the obj's position to 1
            else:
                self.obj_state[..., i] = 0
                state = obj.get_state()
                self.obj_state[state[0], state[1], i] = 1

    def render(self, mode='human'):
        """Render through a lazily created ``Viewer`` and return its result.

        ``mode`` is only used the first time, when the viewer is created.
        """
        if self.viewer is None:
            self.viewer = Viewer(mode=mode)

        return self.viewer.render(self)

class BallDropEnv(Env):
    """Environment with objects named 'ball_a' / 'ball_b' and an 'agent'.

    In the flat state each ball contributes two entries (x, y). A held
    ball is encoded by setting its y entry to ``state_space[1] - 1``; the
    balls' own state carries a third "held" entry that is not part of the
    flat state. Propositions are not part of the flat state here; they
    are re-evaluated whenever the state is set.
    """

    def __init__(self, name, dom_size, action_dict):
        super().__init__(name, dom_size, action_dict)

    def make_reward_function(self):
        """Return the per-state reward vector.

        Reward is 20 where both 'ainb' and 'binb' hold, 10 where only
        'ainb' holds, and 0 elsewhere. The state the environment had on
        entry is restored before returning.
        """
        initial_state = self.get_state()

        # find the size of the state space
        full_state_space = self.get_full_state_space()
        ss_size = np.prod(full_state_space)

        # define a reward function (a vector storing the reward of each state)
        R = np.zeros((ss_size,))

        # ndindex iterates through every state in (x, y, ...)
        for count, state in enumerate(np.ndindex(*full_state_space), start=1):
            state = list(state)
            s_idx = self.state_to_idx(state)
            self.set_state(state)

            # the goal condition
            if self.prop_dict['ainb'].value:
                R[s_idx] = 20 if self.prop_dict['binb'].value else 10

            # progress output
            if count % 10000 == 0:
                print(count)

        self.set_state(initial_state)

        return R

    def get_state(self):
        """Return the flat state of all non-static objects (no props)."""
        state = []
        for obj in self.objects:
            # don't include obstacles
            if type(obj).mro()[1].__name__ != 'StaticObj':
                if obj.name == 'ball_a' or obj.name == 'ball_b':
                    state.append(obj.state[0])
                    # if ball is being held, set ball's
                    # y value to be state_space[1]-1
                    # (aka dom_size[1]-1)
                    if obj.state[2]:
                        state.append(obj.state_space[1] - 1)
                    else:
                        state.append(obj.state[1])
                else:
                    state.extend(list(obj.get_state()))

        return state

    # given state of form [3, 2, 4, 8, 1, 0, 0, 0]
    # set object and prop states to match
    def set_state(self, state):
        """Push a flat state into the objects and re-evaluate the props.

        ``state`` may be any indexable sequence (list, tuple, ndarray); it
        is never modified.
        """
        idx = 0  # index along the whole state
        ball_being_held = False
        for obj in self.objects:
            # don't include obstacles
            if type(obj).mro()[1].__name__ != 'StaticObj':
                ndim = len(obj.state_space)
                # Copy into a fresh list: the ball branch below assigns to
                # and appends to obj_state, which would fail on a tuple
                # and would write through to the caller's data on an
                # ndarray slice.
                obj_state = list(state[idx:idx + ndim])
                if obj.name == 'ball_a' or obj.name == 'ball_b':
                    # if the ball's y value is state_space[1]-1,
                    # then it is being held
                    if obj_state[1] == obj.state_space[1] - 1:
                        obj_state[1] = self.dom_size[1] - 2
                        # 3rd dimension = 1 bc it's being held
                        obj_state.append(1)
                        ball_being_held = True
                    else:
                        # 3rd dimension = 0 bc it's not being held
                        obj_state.append(0)
                obj.set_state(obj_state)
                idx += ndim
        # the agent's third state entry records whether any ball is held
        self.obj_dict['agent'].state[2] = int(ball_being_held)
        self.update_obj_state()
        for prop in self.props:
            prop.eval(self.obj_dict)

    # ie [8, 6, 8, 6]
    #     x, y, x, y
    def get_full_state_space(self):
        """Return the size of every flat-state dimension (non-static objects only)."""
        state_space = []
        for obj in self.objects:
            # don't include obstacles
            if type(obj).mro()[1].__name__ != 'StaticObj':
                state_space.extend(obj.state_space)
        return state_space

class LineWorldEnv(Env):
    """Environment whose flat state is the 'agent' state over ``dom_size[0]`` values."""

    def __init__(self, name, dom_size, action_dict):
        super().__init__(name, dom_size, action_dict)

    def make_reward_function(self):
        """Return the per-state reward vector.

        A state gets reward 10 when 'ona' or 'onb' holds, 0 otherwise.
        The state the environment had on entry is restored at the end.
        """
        saved_state = self.get_state()

        # one reward entry per state
        space = self.get_full_state_space()
        n_states = np.prod(space)
        rewards = np.zeros((n_states,))

        # visit every state in the space
        for visited, point in enumerate(np.ndindex(*space), start=1):
            point = list(point)
            flat = self.state_to_idx(point)
            self.set_state(point)

            # goal reached on either target
            props = self.prop_dict
            if props['ona'].value or props['onb'].value:
                rewards[flat] = 10

            # progress output
            if visited % 10000 == 0:
                print(visited)

        self.set_state(saved_state)
        return rewards

    def get_state(self):
        """Return the state reported by the 'agent' object."""
        return self.obj_dict['agent'].get_state()

    def set_state(self, state):
        """Set the agent from ``state[0]``, refresh ``obj_state`` and re-evaluate props."""
        agent = self.obj_dict['agent']
        agent.set_state(state[0])
        self.update_obj_state()
        for proposition in self.props:
            proposition.eval(self.obj_dict)

    def get_full_state_space(self):
        """Return ``[dom_size[0]]``, e.g. ``[8]``."""
        return [self.dom_size[0]]

class GridWorldEnv(Env):
    """Environment whose flat state is the 'agent' state over ``dom_size``."""

    def __init__(self, name, dom_size, action_dict):
        super().__init__(name, dom_size, action_dict)

    def make_reward_function(self):
        """Return the per-state reward vector.

        A state gets reward 10 when 'ona', 'onb' or 'onc' holds; otherwise
        -1000 when 'onobstacle' holds; otherwise 0. The state the
        environment had on entry is restored at the end.
        """
        saved_state = self.get_state()

        # one reward entry per state
        space = self.get_full_state_space()
        n_states = np.prod(space)
        rewards = np.zeros((n_states,))

        # visit every state in the space
        for visited, point in enumerate(np.ndindex(*space), start=1):
            point = list(point)
            flat = self.state_to_idx(point)
            self.set_state(point)

            props = self.prop_dict
            on_goal = props['ona'].value or props['onb'].value or props['onc'].value
            if on_goal:
                rewards[flat] = 10
            elif props['onobstacle'].value:
                rewards[flat] = -1000

            # progress output
            if visited % 10000 == 0:
                print(visited)

        self.set_state(saved_state)
        return rewards

    def get_state(self):
        """Return the state reported by the 'agent' object."""
        return self.obj_dict['agent'].get_state()

    def set_state(self, state):
        """Set the agent from ``state[0:2]``, refresh ``obj_state`` and re-evaluate props."""
        agent = self.obj_dict['agent']
        agent.set_state(state[0:2])
        self.update_obj_state()
        for proposition in self.props:
            proposition.eval(self.obj_dict)

    def get_full_state_space(self):
        """Return ``dom_size`` itself (the same object, not a copy)."""
        return self.dom_size

class DriveWorldEnv(Env):
    """Environment whose flat state is the 'agent' state over ``dom_size``.

    Also exposes the agent's RRT state and an ``option_start`` slot that
    is initialised to None.
    """

    def __init__(self, name, dom_size, action_dict):
        super().__init__(name, dom_size, action_dict)

        self.option_start = None

    def make_reward_function(self):
        """Return the per-state reward vector.

        A state gets reward 10 when 'ona', 'onb' or 'onc' holds; otherwise
        -1000 when 'onobstacle' holds; otherwise 0. The state the
        environment had on entry is restored at the end.
        """
        saved_state = self.get_state()

        # one reward entry per state
        space = self.get_full_state_space()
        n_states = np.prod(space)
        rewards = np.zeros((n_states,))

        # visit every state in the space
        for visited, point in enumerate(np.ndindex(*space), start=1):
            point = list(point)
            flat = self.state_to_idx(point)
            self.set_state(point)

            props = self.prop_dict
            on_goal = props['ona'].value or props['onb'].value or props['onc'].value
            if on_goal:
                rewards[flat] = 10
            elif props['onobstacle'].value:
                rewards[flat] = -1000

            # progress output
            if visited % 10000 == 0:
                print(visited)

        self.set_state(saved_state)
        return rewards

    def get_state(self):
        """Return the state reported by the 'agent' object."""
        return self.obj_dict['agent'].get_state()

    def get_rrt_state(self):
        """Return the RRT state reported by the 'agent' object."""
        return self.obj_dict['agent'].get_rrt_state()

    def set_state(self, state):
        """Hand the whole ``state`` to the agent, refresh ``obj_state`` and re-evaluate props."""
        agent = self.obj_dict['agent']
        agent.set_state(state)
        self.update_obj_state()
        for proposition in self.props:
            proposition.eval(self.obj_dict)

    def get_full_state_space(self):
        """Return ``dom_size`` itself (the same object, not a copy)."""
        # TODO: an extra dimension of size 4 for the 4 maneuvers was
        # planned here and still needs to be parameterized.
        return self.dom_size